{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lwSKH5kZsLP5"
      },
      "outputs": [],
      "source": [
        "# Install latest bitsandbytes & transformers, accelerate from source\n",
        "!pip install -q -U bitsandbytes\n",
        "!pip install -q -U git+https://github.com/huggingface/transformers.git\n",
        "!pip install -q -U git+https://github.com/huggingface/peft.git\n",
        "!pip install -q -U git+https://github.com/huggingface/accelerate.git\n",
        "# Other requirements for the demo\n",
        "!pip install gradio\n",
        "!pip install sentencepiece"
      ]
    },
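    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional sanity check (a minimal sketch, not part of the original demo):\n",
        "# confirm a CUDA GPU is visible before downloading the model. In Colab,\n",
        "# select Runtime > Change runtime type > GPU if this prints False.\n",
        "import torch\n",
        "\n",
        "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
        "if torch.cuda.is_available():\n",
        "    print(f\"Device: {torch.cuda.get_device_name(0)}\")"
      ]
    },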
    {
      "cell_type": "code",
      "source": [
        "# Load the model.\n",
        "# Note: It can take a while to download LLaMA and add the adapter modules.\n",
        "# You can also use the 13B model by loading in 4bits.\n",
        "\n",
        "import torch\n",
        "from peft import PeftModel    \n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer\n",
        "\n",
        "model_name = \"decapoda-research/llama-7b-hf\"\n",
        "adapters_name = 'timdettmers/guanaco-7b'\n",
        "\n",
        "print(f\"Starting to load the model {model_name} into memory\")\n",
        "\n",
        "m = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    #load_in_4bit=True,\n",
        "    torch_dtype=torch.bfloat16,\n",
        "    device_map={\"\": 0}\n",
        ")\n",
        "m = PeftModel.from_pretrained(m, adapters_name)\n",
        "m = m.merge_and_unload()\n",
        "tok = LlamaTokenizer.from_pretrained(model_name)\n",
        "tok.bos_token_id = 1\n",
        "\n",
        "stop_token_ids = [0]\n",
        "\n",
        "print(f\"Successfully loaded the model {model_name} into memory\")"
      ],
      "metadata": {
        "id": "2QK51MtdsMLu"
      },
      "execution_count": null,
      "outputs": []
    },
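    {
      "cell_type": "code",
      "source": [
        "# Optional smoke test (a minimal sketch, not part of the original demo): run\n",
        "# one short greedy generation to confirm the merged model works before wiring\n",
        "# up the UI. The \"### Human: / ### Assistant:\" layout mirrors the prompt\n",
        "# format used by convert_history_to_text in the next cell.\n",
        "prompt = \"### Human: What is the capital of France?\\n### Assistant:\"\n",
        "inputs = tok(prompt, return_tensors=\"pt\").to(m.device)\n",
        "with torch.no_grad():\n",
        "    out = m.generate(**inputs, max_new_tokens=32, do_sample=False)\n",
        "print(tok.decode(out[0], skip_special_tokens=True))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },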
    {
      "cell_type": "code",
      "source": [
        "# Setup the gradio Demo.\n",
        "\n",
        "import datetime\n",
        "import os\n",
        "from threading import Event, Thread\n",
        "from uuid import uuid4\n",
        "\n",
        "import gradio as gr\n",
        "import requests\n",
        "\n",
        "max_new_tokens = 1536\n",
        "start_message = \"\"\"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\"\"\"\n",
        "\n",
        "class StopOnTokens(StoppingCriteria):\n",
        "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n",
        "        for stop_id in stop_token_ids:\n",
        "            if input_ids[0][-1] == stop_id:\n",
        "                return True\n",
        "        return False\n",
        "\n",
        "\n",
        "def convert_history_to_text(history):\n",
        "    text = start_message + \"\".join(\n",
        "        [\n",
        "            \"\".join(\n",
        "                [\n",
        "                    f\"### Human: {item[0]}\\n\",\n",
        "                    f\"### Assistant: {item[1]}\\n\",\n",
        "                ]\n",
        "            )\n",
        "            for item in history[:-1]\n",
        "        ]\n",
        "    )\n",
        "    text += \"\".join(\n",
        "        [\n",
        "            \"\".join(\n",
        "                [\n",
        "                    f\"### Human: {history[-1][0]}\\n\",\n",
        "                    f\"### Assistant: {history[-1][1]}\\n\",\n",
        "                ]\n",
        "            )\n",
        "        ]\n",
        "    )\n",
        "    return text\n",
        "\n",
        "\n",
        "def log_conversation(conversation_id, history, messages, generate_kwargs):\n",
        "    logging_url = os.getenv(\"LOGGING_URL\", None)\n",
        "    if logging_url is None:\n",
        "        return\n",
        "\n",
        "    timestamp = datetime.datetime.now().strftime(\"%Y-%m-%dT%H:%M:%S\")\n",
        "\n",
        "    data = {\n",
        "        \"conversation_id\": conversation_id,\n",
        "        \"timestamp\": timestamp,\n",
        "        \"history\": history,\n",
        "        \"messages\": messages,\n",
        "        \"generate_kwargs\": generate_kwargs,\n",
        "    }\n",
        "\n",
        "    try:\n",
        "        requests.post(logging_url, json=data)\n",
        "    except requests.exceptions.RequestException as e:\n",
        "        print(f\"Error logging conversation: {e}\")\n",
        "\n",
        "\n",
        "def user(message, history):\n",
        "    # Append the user's message to the conversation history\n",
        "    return \"\", history + [[message, \"\"]]\n",
        "\n",
        "\n",
        "def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):\n",
        "    print(f\"history: {history}\")\n",
        "    # Initialize a StopOnTokens object\n",
        "    stop = StopOnTokens()\n",
        "\n",
        "    # Construct the input message string for the model by concatenating the current system message and conversation history\n",
        "    messages = convert_history_to_text(history)\n",
        "\n",
        "    # Tokenize the messages string\n",
        "    input_ids = tok(messages, return_tensors=\"pt\").input_ids\n",
        "    input_ids = input_ids.to(m.device)\n",
        "    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)\n",
        "    generate_kwargs = dict(\n",
        "        input_ids=input_ids,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        temperature=temperature,\n",
        "        do_sample=temperature > 0.0,\n",
        "        top_p=top_p,\n",
        "        top_k=top_k,\n",
        "        repetition_penalty=repetition_penalty,\n",
        "        streamer=streamer,\n",
        "        stopping_criteria=StoppingCriteriaList([stop]),\n",
        "    )\n",
        "\n",
        "    stream_complete = Event()\n",
        "\n",
        "    def generate_and_signal_complete():\n",
        "        m.generate(**generate_kwargs)\n",
        "        stream_complete.set()\n",
        "\n",
        "    def log_after_stream_complete():\n",
        "        stream_complete.wait()\n",
        "        log_conversation(\n",
        "            conversation_id,\n",
        "            history,\n",
        "            messages,\n",
        "            {\n",
        "                \"top_k\": top_k,\n",
        "                \"top_p\": top_p,\n",
        "                \"temperature\": temperature,\n",
        "                \"repetition_penalty\": repetition_penalty,\n",
        "            },\n",
        "        )\n",
        "\n",
        "    t1 = Thread(target=generate_and_signal_complete)\n",
        "    t1.start()\n",
        "\n",
        "    t2 = Thread(target=log_after_stream_complete)\n",
        "    t2.start()\n",
        "\n",
        "    # Initialize an empty string to store the generated text\n",
        "    partial_text = \"\"\n",
        "    for new_text in streamer:\n",
        "        partial_text += new_text\n",
        "        history[-1][1] = partial_text\n",
        "        yield history\n",
        "\n",
        "\n",
        "def get_uuid():\n",
        "    return str(uuid4())\n",
        "\n",
        "\n",
        "with gr.Blocks(\n",
        "    theme=gr.themes.Soft(),\n",
        "    css=\".disclaimer {font-variant-caps: all-small-caps;}\",\n",
        ") as demo:\n",
        "    conversation_id = gr.State(get_uuid)\n",
        "    gr.Markdown(\n",
        "        \"\"\"<h1><center>Guanaco Demo</center></h1>\n",
        "\"\"\"\n",
        "    )\n",
        "    chatbot = gr.Chatbot().style(height=500)\n",
        "    with gr.Row():\n",
        "        with gr.Column():\n",
        "            msg = gr.Textbox(\n",
        "                label=\"Chat Message Box\",\n",
        "                placeholder=\"Chat Message Box\",\n",
        "                show_label=False,\n",
        "            ).style(container=False)\n",
        "        with gr.Column():\n",
        "            with gr.Row():\n",
        "                submit = gr.Button(\"Submit\")\n",
        "                stop = gr.Button(\"Stop\")\n",
        "                clear = gr.Button(\"Clear\")\n",
        "    with gr.Row():\n",
        "        with gr.Accordion(\"Advanced Options:\", open=False):\n",
        "            with gr.Row():\n",
        "                with gr.Column():\n",
        "                    with gr.Row():\n",
        "                        temperature = gr.Slider(\n",
        "                            label=\"Temperature\",\n",
        "                            value=0.7,\n",
        "                            minimum=0.0,\n",
        "                            maximum=1.0,\n",
        "                            step=0.1,\n",
        "                            interactive=True,\n",
        "                            info=\"Higher values produce more diverse outputs\",\n",
        "                        )\n",
        "                with gr.Column():\n",
        "                    with gr.Row():\n",
        "                        top_p = gr.Slider(\n",
        "                            label=\"Top-p (nucleus sampling)\",\n",
        "                            value=0.9,\n",
        "                            minimum=0.0,\n",
        "                            maximum=1,\n",
        "                            step=0.01,\n",
        "                            interactive=True,\n",
        "                            info=(\n",
        "                                \"Sample from the smallest possible set of tokens whose cumulative probability \"\n",
        "                                \"exceeds top_p. Set to 1 to disable and sample from all tokens.\"\n",
        "                            ),\n",
        "                        )\n",
        "                with gr.Column():\n",
        "                    with gr.Row():\n",
        "                        top_k = gr.Slider(\n",
        "                            label=\"Top-k\",\n",
        "                            value=0,\n",
        "                            minimum=0.0,\n",
        "                            maximum=200,\n",
        "                            step=1,\n",
        "                            interactive=True,\n",
        "                            info=\"Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.\",\n",
        "                        )\n",
        "                with gr.Column():\n",
        "                    with gr.Row():\n",
        "                        repetition_penalty = gr.Slider(\n",
        "                            label=\"Repetition Penalty\",\n",
        "                            value=1.0,\n",
        "                            minimum=1.0,\n",
        "                            maximum=2.0,\n",
        "                            step=0.1,\n",
        "                            interactive=True,\n",
        "                            info=\"Penalize repetition — 1.0 to disable.\",\n",
        "                        )\n",
        "    with gr.Row():\n",
        "        gr.Markdown(\n",
        "            \"Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce \"\n",
        "            \"factually accurate information. The model was trained on various public datasets; while great efforts \"\n",
        "            \"have been taken to clean the pretraining data, it is possible that this model could generate lewd, \"\n",
        "            \"biased, or otherwise offensive outputs.\",\n",
        "            elem_classes=[\"disclaimer\"],\n",
        "        )\n",
        "    with gr.Row():\n",
        "        gr.Markdown(\n",
        "            \"[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)\",\n",
        "            elem_classes=[\"disclaimer\"],\n",
        "        )\n",
        "\n",
        "    submit_event = msg.submit(\n",
        "        fn=user,\n",
        "        inputs=[msg, chatbot],\n",
        "        outputs=[msg, chatbot],\n",
        "        queue=False,\n",
        "    ).then(\n",
        "        fn=bot,\n",
        "        inputs=[\n",
        "            chatbot,\n",
        "            temperature,\n",
        "            top_p,\n",
        "            top_k,\n",
        "            repetition_penalty,\n",
        "            conversation_id,\n",
        "        ],\n",
        "        outputs=chatbot,\n",
        "        queue=True,\n",
        "    )\n",
        "    submit_click_event = submit.click(\n",
        "        fn=user,\n",
        "        inputs=[msg, chatbot],\n",
        "        outputs=[msg, chatbot],\n",
        "        queue=False,\n",
        "    ).then(\n",
        "        fn=bot,\n",
        "        inputs=[\n",
        "            chatbot,\n",
        "            temperature,\n",
        "            top_p,\n",
        "            top_k,\n",
        "            repetition_penalty,\n",
        "            conversation_id,\n",
        "        ],\n",
        "        outputs=chatbot,\n",
        "        queue=True,\n",
        "    )\n",
        "    stop.click(\n",
        "        fn=None,\n",
        "        inputs=None,\n",
        "        outputs=None,\n",
        "        cancels=[submit_event, submit_click_event],\n",
        "        queue=False,\n",
        "    )\n",
        "    clear.click(lambda: None, None, chatbot, queue=False)\n",
        "\n",
        "demo.queue(max_size=128, concurrency_count=2)\n"
      ],
      "metadata": {
        "id": "aklTR-es2bma"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Launch your Guanaco Demo!\n",
        "demo.launch()"
      ],
      "metadata": {
        "id": "e0nzyqUks49E"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}